/*
 * Copyright (c) 2020-2021. the original authors and DEPSEA.ORG
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.depsea.log.aop.advice;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.depsea.log.exception.LogPointAspectException;
import org.depsea.log.getter.ApiNameGetter;
import org.depsea.log.getter.UserGetter;
import org.depsea.log.handler.RequestPointHandler;
import org.depsea.log.point.RequestPoint;
import org.springframework.aop.AfterReturningAdvice;
import org.springframework.aop.MethodBeforeAdvice;
import org.springframework.aop.ThrowsAdvice;
import org.springframework.beans.BeansException;
import org.springframework.cloud.sleuth.Tracer;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.lang.NonNull;
import org.springframework.util.CollectionUtils;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Spring AOP advice that builds a {@link RequestPoint} for each advised method invocation made while a
 * servlet request is bound to the current thread, and hands it to the configured
 * {@link RequestPointHandler}s (synchronously or on an internal thread pool, per handler).
 * <p>
 * The elapsed time is measured from {@link #before} to {@link #afterReturning} / {@link #afterThrowing}.
 * Start times are kept on a per-thread stack so nested advised calls on the same thread each get their
 * own measurement instead of overwriting one another.
 *
 * @author jaune
 * @since 1.0.0
 */
@Slf4j
public class RequestAdvice implements AfterReturningAdvice, ThrowsAdvice, MethodBeforeAdvice, ApplicationContextAware {

    /**
     * Elapsed time reported when no start time was recorded for the current invocation
     * (i.e. {@link #before} did not run for it).
     */
    private static final long UNKNOWN_TIME_CONSUMING = -1L;

    /**
     * Thread pool used to run asynchronous handlers.
     */
    private final ExecutorService executorService = Executors.newCachedThreadPool();

    /**
     * Whether request headers are included in the request point.
     */
    private boolean includeRequestHeaders = true;

    /**
     * Whether response headers are included in the request point.
     */
    private boolean includeResponseHeaders = true;

    private ApiNameGetter apiNameGetter;

    private UserGetter userGetter;

    /**
     * Requires spring-cloud-starter-sleuth; used to read tracing information (trace/span/parent ids).
     */
    private Tracer tracer;

    /**
     * Start timestamps (epoch millis) of the advised invocations in progress on this thread, most recent
     * first. A stack rather than a single value so nested advised calls do not clobber each other.
     * NOTE: entries are popped only in afterReturning / afterThrowing(..., Exception).
     */
    private final ThreadLocal<Deque<Long>> timeConsumingThreadLocal = ThreadLocal.withInitial(ArrayDeque::new);

    /**
     * JSON serializer.
     */
    private final ObjectMapper objectMapper = new ObjectMapper();

    /**
     * Handlers that receive every request point.
     */
    private List<RequestPointHandler> requestPointHandlers = new ArrayList<>();

    /**
     * Spring application context, used to read {@code spring.application.name}.
     */
    private ApplicationContext applicationContext;

    public void setApiNameGetter(ApiNameGetter apiNameGetter) {
        this.apiNameGetter = apiNameGetter;
    }

    public void setTracer(Tracer tracer) {
        this.tracer = tracer;
    }

    public void setIncludeRequestHeaders(boolean includeRequestHeaders) {
        this.includeRequestHeaders = includeRequestHeaders;
    }

    public void setIncludeResponseHeaders(boolean includeResponseHeaders) {
        this.includeResponseHeaders = includeResponseHeaders;
    }

    /**
     * Replaces the handler list.
     *
     * @param requestPointHandlers handlers to use; {@code null} means no handlers. The list is copied so
     *                             that {@link #addPointHandlers} keeps working even when an immutable list
     *                             (e.g. {@code List.of(...)}) is supplied.
     */
    public void setPointHandlers(List<RequestPointHandler> requestPointHandlers) {
        this.requestPointHandlers = requestPointHandlers == null
                ? new ArrayList<>()
                : new ArrayList<>(requestPointHandlers);
    }

    /**
     * Appends handlers to the current handler list.
     *
     * @param requestPointHandlers handlers to add
     */
    public void addPointHandlers(RequestPointHandler... requestPointHandlers) {
        Collections.addAll(this.requestPointHandlers, requestPointHandlers);
    }

    public void setUserGetter(UserGetter userGetter) {
        this.userGetter = userGetter;
    }

    @Override
    public void setApplicationContext(@NonNull ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }

    /**
     * Records the start time of the advised invocation; popped again in afterReturning / afterThrowing.
     */
    @Override
    public void before(@NonNull Method method, @NonNull Object[] args, Object target) throws Throwable {
        this.timeConsumingThreadLocal.get().push(System.currentTimeMillis());
    }

    /**
     * Builds a successful request point (including the JSON-serialized return value, if any) and
     * dispatches it. Any failure while doing so is logged and swallowed so logging never breaks the call.
     */
    @Override
    public void afterReturning(Object returnValue, @NonNull Method method, @NonNull Object[] args, Object target) throws Throwable {
        // Pop first so the start time is released even if building the request point fails below.
        long timeConsuming = this.popTimeConsuming();
        try {
            var requestPoint = this.createRequestPoint(method, args, timeConsuming);
            if (returnValue != null) {
                requestPoint.setReturnValue(this.objectMapper.writeValueAsString(returnValue));
            }
            this.doRequestHandler(requestPoint);
        } catch (Exception ex) {
            log.warn("处理失败：{}", ex.getMessage(), ex);
        }
    }

    /**
     * Builds an error request point describing {@code ex} and dispatches it. Any failure while doing so is
     * logged and swallowed; the original exception continues to propagate to the caller.
     */
    public void afterThrowing(Method method, Object[] args, Object target, Exception ex) {
        // Pop first so the start time is released even if building the request point fails below.
        long timeConsuming = this.popTimeConsuming();
        try {
            var requestPoint = this.createRequestPoint(method, args, timeConsuming);
            requestPoint.setError(true);
            requestPoint.setExceptionName(ex.getClass().getName());
            requestPoint.setExceptionStack(ExceptionUtils.getStackTrace(ex));
            requestPoint.setErrorMessage(ex.getMessage());
            this.doRequestHandler(requestPoint);
        } catch (Exception e) {
            log.warn("处理失败：{}", e.getMessage(), e);
        }
    }

    /**
     * Pops the start time pushed by {@link #before} for the current invocation and returns the elapsed
     * milliseconds. Removes the thread-local once no advised invocation is in progress on this thread, so
     * nothing is left behind on pooled threads.
     *
     * @return elapsed milliseconds, or {@link #UNKNOWN_TIME_CONSUMING} if no start time was recorded
     */
    private long popTimeConsuming() {
        long end = System.currentTimeMillis();
        Deque<Long> startTimes = this.timeConsumingThreadLocal.get();
        Long start = startTimes.pollFirst();
        if (startTimes.isEmpty()) {
            this.timeConsumingThreadLocal.remove();
        }
        return start == null ? UNKNOWN_TIME_CONSUMING : end - start;
    }

    /**
     * Collects request, response, user, tracing and method information into a new {@link RequestPoint}.
     *
     * @param method        the advised method
     * @param args          the invocation arguments
     * @param timeConsuming elapsed milliseconds for the invocation
     * @return a request point with {@code error=false}
     * @throws JsonProcessingException      if headers, parameters or arguments cannot be serialized
     * @throws LogPointAspectException      if no request is bound to the current thread
     */
    private RequestPoint createRequestPoint(@NonNull Method method, @NonNull Object[] args, long timeConsuming)
            throws JsonProcessingException {
        var request = this.getRequest();
        var response = this.getResponse();

        var requestPoint = new RequestPoint();

        if (this.includeRequestHeaders) {
            requestPoint.setRequestHeaders(this.objectMapper.writeValueAsString(this.getRequestHeaders(request)));
        }
        if (this.includeResponseHeaders && response != null) {
            requestPoint.setResponseHeaders(this.objectMapper.writeValueAsString(this.getResponseHeaders(response)));
        }

        if (this.apiNameGetter != null) {
            requestPoint.setModuleName(this.apiNameGetter.getModuleName(method.getDeclaringClass()));
            requestPoint.setApiName(this.apiNameGetter.getOperationName(method));
        }

        if (this.userGetter != null) {
            requestPoint.setUserId(this.userGetter.getUserId());
            requestPoint.setName(this.userGetter.getName());
        }

        requestPoint.setApplicationName(this.getApplicationName());
        requestPoint.setClazz(method.getDeclaringClass().getName());
        requestPoint.setMethodName(method.getName());
        requestPoint.setError(false);
        requestPoint.setTimestamp(new Date());

        if (this.tracer != null) {
            var span = this.tracer.currentSpan();
            if (span != null) {
                requestPoint.setSpanId(span.context().spanId());
                requestPoint.setTraceId(span.context().traceId());
                requestPoint.setParentId(span.context().parentId());
            }
        }

        // Locale.ROOT: upper-casing must not depend on the JVM default locale.
        requestPoint.setSchema(request.getScheme().toUpperCase(Locale.ROOT));
        requestPoint.setRequestMethod(request.getMethod().toUpperCase(Locale.ROOT));
        requestPoint.setRequestUri(request.getRequestURI());

        requestPoint.setMethodParameterMap(this.objectMapper.writeValueAsString(this.getMethodParams(method, args)));
        requestPoint.setRequestParameterMap(this.objectMapper.writeValueAsString(request.getParameterMap()));
        if (response != null) {
            requestPoint.setResponseStatus(response.getStatus());
        }

        requestPoint.setTimeConsuming(timeConsuming);
        return requestPoint;
    }

    /**
     * Maps each parameter name (as reported by reflection) to its argument, in declaration order.
     * Servlet request/response arguments are skipped: Jackson would serialize them through their getters,
     * which may throw or touch the request body and would make the whole request point fail.
     */
    private Map<String, Object> getMethodParams(Method method, Object[] args) {
        Map<String, Object> methodParams = new LinkedHashMap<>();
        Parameter[] parameters = method.getParameters();
        for (var i = 0; i < parameters.length; i++) {
            var arg = args[i];
            if (arg instanceof HttpServletRequest || arg instanceof HttpServletResponse) {
                continue;
            }
            methodParams.put(parameters[i].getName(), arg);
        }
        return methodParams;
    }

    /**
     * @return the {@code spring.application.name} property, or {@code null} if it is unset or no
     * application context has been injected
     */
    private String getApplicationName() {
        if (this.applicationContext == null) {
            return null;
        }
        return this.applicationContext.getEnvironment().getProperty("spring.application.name");
    }

    /**
     * Collects request headers; for multi-valued headers only the first value is kept.
     */
    private Map<String, String> getRequestHeaders(HttpServletRequest request) {
        Map<String, String> headers = new HashMap<>();
        Enumeration<String> headerNames = request.getHeaderNames();
        while (headerNames.hasMoreElements()) {
            String headerName = headerNames.nextElement();
            headers.put(headerName, request.getHeader(headerName));
        }
        return headers;
    }

    /**
     * Collects response headers; for multi-valued headers only the first value is kept.
     */
    private Map<String, String> getResponseHeaders(HttpServletResponse response) {
        Map<String, String> headers = new HashMap<>();
        Collection<String> headerNames = response.getHeaderNames();
        if (CollectionUtils.isEmpty(headerNames)) {
            return headers;
        }
        for (String headerName : headerNames) {
            headers.put(headerName, response.getHeader(headerName));
        }
        return headers;
    }

    private HttpServletRequest getRequest() {
        return this.getServletRequestAttributes(
                "Could not get the HttpServletRequest from the spring webmvc context.").getRequest();
    }

    private HttpServletResponse getResponse() {
        return this.getServletRequestAttributes(
                "Could not get the HttpServletResponse from the spring webmvc context.").getResponse();
    }

    /**
     * Returns the servlet request attributes bound to the current thread.
     *
     * @param errorMessage message of the exception thrown when nothing is bound
     * @throws LogPointAspectException if no request attributes are bound to the current thread
     */
    private ServletRequestAttributes getServletRequestAttributes(String errorMessage) {
        RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
        if (attributes == null) {
            throw new LogPointAspectException(errorMessage);
        }
        return (ServletRequestAttributes) attributes;
    }

    /**
     * Passes the request point to every handler: inline for synchronous handlers, on the thread pool for
     * asynchronous ones. A failing handler is logged and does not prevent the others from running.
     */
    private void doRequestHandler(RequestPoint requestPoint) {
        if (CollectionUtils.isEmpty(this.requestPointHandlers)) {
            return;
        }
        for (RequestPointHandler requestPointHandler : this.requestPointHandlers) {
            try {
                if (requestPointHandler.isAsync()) {
                    // execute + wrapper instead of submit: submit would capture the handler's exception in
                    // a discarded Future, so async failures would never be logged.
                    this.executorService.execute(() -> this.invokeHandler(requestPointHandler, requestPoint));
                } else {
                    this.invokeHandler(requestPointHandler, requestPoint);
                }
            } catch (Exception ex) {
                // isAsync() failed or the executor rejected the task.
                log.warn("Execute request point handler [{}] fail. Cause: {}",
                        requestPointHandler.getClass().getName(), ex.getMessage(), ex);
            }
        }
    }

    /**
     * Runs a single handler, logging (not propagating) any exception it throws.
     */
    private void invokeHandler(RequestPointHandler requestPointHandler, RequestPoint requestPoint) {
        try {
            requestPointHandler.handle(requestPoint);
        } catch (Exception ex) {
            log.warn("Execute request point handler [{}] fail. Cause: {}",
                    requestPointHandler.getClass().getName(), ex.getMessage(), ex);
        }
    }
}